9022. Подсчитайте
тройки
Заданы три массива a, b
и c, каждый из которых состоит из
n целых чисел. Найдите количество
троек (ai, bj, ck), для которых выполняется неравенство ai < bj < ck.
Вход. Первая
строка содержит размер массивов n (n ≤ 105).
Вторая
строка содержит элементы массива a.
Третья
строка содержит элементы массива b.
Четвертая
строка содержит элементы массива c.
Выход. Выведите количество троек (ai, bj, ck),
удовлетворяющих условию ai
< bj < ck.
Пояснение. В первом тесте искомыми
тройками будут (a1, b1, c1), (a1, b2, c1) и (a1, b2, c2).
Пример
входа 1 |
Пример
выхода 1 |
2 1 5 4 2 6 3 |
3 |
|
|
Пример
входа 2 |
Пример
выхода 2 |
3 1 1 1 2 2 2 3 3 3 |
27 |
бинарный
поиск
Отсортируем
все три массива. Для каждого элемента bj при
помощи бинарного поиска определим:
·
количество элементов x в массиве a,
которые меньше bj,
·
количество элементов y в массиве c,
которые больше bj.
Тогда для фиксированного значения bj существует ровно x * y троек вида (ai, bj, ck), удовлетворяющих неравенству ai
< bj < ck.
Пример
Рассмотрим
отсортированные массивы и вычислим количество подходящих троек для b5 = 10.
Имеем: ai < b5 при i ≤ 5, а ck > b5 при k ≥ 7.
Таким
образом, неравенство ai < b5 < ck выполняется для 1 ≤ i ≤ 5 и 7 ≤ k ≤ 8.
Количество
троек (ai, b5, ck) равно 5 * 2 = 10.
Объявим рабочие массивы.
#define MAX
100000
int a[MAX], b[MAX], c[MAX];
Читаем входные массивы.
scanf("%d", &n);
for (i = 0; i < n; i++) scanf("%d",
&a[i]);
for (i = 0; i < n; i++) scanf("%d",
&b[i]);
for (i = 0; i < n; i++) scanf("%d",
&c[i]);
Сортируем массивы.
sort(a, a + n);
sort(b, b + n);
sort(c, c + n);
Количество искомых троек будем подсчитывать
в переменной res. Перебираем значения bj.
res = 0;
for (j = 0; j < n; j++)
{
Количество элементов массива a, меньших
bj, равно x.
x =
lower_bound(a, a + n, b[j]) - a;
Количество элементов массива c, больших
bj, равно y.
y = n
- (upper_bound(c, c + n, b[j]) - c);
Тогда для значения bj существует ровно x *
y искомых троек.
res
+= x * y;
}
Выводим ответ.
printf("%lld\n",
res);
import java.util.*;
public class Main
{
static int
lower_bound(int m[], int start, int end, int x)
{
while (start <
end)
{
int mid = (start + end) /
2;
if (x
<= m[mid])
end = mid;
else
start = mid + 1;
}
return start;
}
static int
upper_bound(int m[], int start, int end, int x)
{
while (start <
end)
{
int mid = (start + end) /
2;
if (x
>= m[mid])
start = mid + 1;
else
end = mid;
}
return start;
}
public static void
main(String[] args)
{
Scanner con = new
Scanner(System.in);
int i, n = con.nextInt();
int a[] = new int[n];
for(i = 0;
i < n; i++) a[i] = con.nextInt();
int b[] = new int[n];
for(i = 0;
i < n; i++) b[i] = con.nextInt();
int c[] = new int[n];
for(i = 0;
i < n; i++) c[i] = con.nextInt();
Arrays.sort(a);
Arrays.sort(b); Arrays.sort(c);
long res = 0;
for (i = 0;
i < n; i++)
{
int x = lower_bound(a, 0, n, b[i]);
int y = n - (upper_bound(c, 0, n, b[i]));
res +=
1L * x * y;
}
System.out.println(res);
con.close();
}
}